import matplotlib.pyplot as plt
import torch
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"



# Generic palette; not referenced by the plots below, kept for compatibility
# with any code that may rely on the name.
col = ['dodgerblue', "tab:orange", "mediumaquamarine", 'lightcoral', 'skyblue', 'sandybrown']

# Colours shared by every figure so the same quantity looks the same everywhere.
color_noise = 'tab:blue'
color_feature = 'lightcoral'
color_loss = 'tab:blue'
color_acc_train = 'tab:blue'
color_acc_test = 'lightcoral'

# Directory that receives every PNG written by this script.
FIGURE_DIR = 'figures'

# Number of leading iterations shown in the "first stage" zoomed-in plots.
FIRST_STAGE_ITERS = 1000

# Legend labels, ordered as (noise, feature1, feature2).
DIFFUSION_LABELS = (
    r'$\max_{r} |\langle w_{r}, \xi_i\rangle|$',
    r'$\max_{r}|\langle {w}_{r}, \boldsymbol{\mu_{1}} \rangle|$',
    r'$\max_{r}|\langle {w}_{r}, \boldsymbol{\mu_{-1}} \rangle|$',
)
CLASSIFICATION_LABELS = (
    r'$\max_{j,r} |\langle w_{j,r}, \xi_i\rangle|$',
    r'$\max_{j,r}|\langle {w}_{j,r}, \boldsymbol{\mu_{1}} \rangle|$',
    r'$\max_{j,r}|\langle {w}_{j,r}, \boldsymbol{\mu_{-1}} \rangle|$',
)

# One entry per experiment: (mu value used in the checkpoint file names,
# suffix appended to the output figure names).
RUNS = ((5, ''), (15, '_15'))


def checkpoint_paths(mu):
    """Return ``(classification_path, diffusion_path)`` for the given ``mu``."""
    return f'syn_class_{mu}.pth', f'syn_diff_0.2_{mu}.pth'


def figure_path(stem, suffix=''):
    """Return the PNG path inside ``FIGURE_DIR`` for a figure ``stem`` and run ``suffix``."""
    return f'{FIGURE_DIR}/{stem}{suffix}.png'


def head(series, t_max):
    """Return ``series`` unchanged when ``t_max`` is None, else its first ``t_max`` items."""
    return series if t_max is None else series[:t_max]


def finish_figure(fig, ax, title, out_path, with_legend):
    """Apply the shared axis styling, save the figure to ``out_path`` and close it."""
    ax.tick_params(axis='both', which='major', labelsize=15)
    ax.tick_params(axis='both', which='minor', labelsize=15)
    ax.set_xlabel('Iteration', fontsize=25)
    ax.set_title(title, fontsize=25)
    if with_legend:
        ax.legend(fontsize=22)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    # Close explicitly: pyplot keeps every figure alive until it is closed,
    # and this script produces a dozen of them.
    plt.close(fig)


def plot_feature_learning(checkpoint, labels, title, out_path, t_max=None):
    """Plot the 'noise', 'feature1' and 'feature2' series of ``checkpoint``.

    Args:
        checkpoint: mapping loaded from a ``.pth`` file; the keys 'noise',
            'feature1' and 'feature2' are read.
        labels: ``(noise, feature1, feature2)`` legend labels.
        title: figure title.
        out_path: destination PNG path.
        t_max: when given, only the first ``t_max`` entries of each series
            are drawn; ``None`` draws the full series.
    """
    noise_label, feature1_label, feature2_label = labels
    noise = checkpoint['noise']
    # NOTE(review): 'noise' is indexed as [0] / [1:] and transposed before
    # plotting, so it is presumably a 2-D (series, iteration) array or
    # tensor -- confirm against the training script.

    fig, ax = plt.subplots(figsize=(7, 6))
    # Only the first noise curve carries a label so the legend shows a
    # single entry for the whole family of noise curves.
    ax.plot(head(noise[0].T, t_max), label=noise_label, color=color_noise, linewidth=3)
    ax.plot(head(noise[1:].T, t_max), color=color_noise, linewidth=3)
    ax.plot(head(checkpoint['feature1'], t_max), label=feature1_label,
            color=color_feature, linewidth=3)
    ax.plot(head(checkpoint['feature2'], t_max), label=feature2_label,
            color=color_feature, linestyle='--', linewidth=3)
    finish_figure(fig, ax, title, out_path, with_legend=True)


def plot_loss(checkpoint, title, out_path):
    """Plot the 'loss' series of ``checkpoint`` and save it to ``out_path``."""
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(checkpoint['loss'], color=color_loss, linewidth=3)
    finish_figure(fig, ax, title, out_path, with_legend=False)


def plot_accuracy(checkpoint, title, out_path):
    """Plot the 'train_acc' and 'test_acc' series of ``checkpoint``."""
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(checkpoint['train_acc'], label='Train acc', color=color_acc_train, linewidth=3)
    ax.plot(checkpoint['test_acc'], label='Test acc', color=color_acc_test, linewidth=3)
    finish_figure(fig, ax, title, out_path, with_legend=True)


def plot_run(mu, suffix):
    """Load both checkpoints for ``mu`` and write the six figures of that run."""
    class_path, diff_path = checkpoint_paths(mu)
    # map_location='cpu' lets checkpoints saved on a GPU be loaded on a
    # machine without one; it has no effect on data already on the CPU.
    checkpoint_class = torch.load(class_path, map_location='cpu')
    checkpoint_diff = torch.load(diff_path, map_location='cpu')

    # Diffusion model: full run, zoom on the first iterations, then loss.
    plot_feature_learning(checkpoint_diff, DIFFUSION_LABELS, 'Diffusion model',
                          figure_path('syn_diff_feat_learn', suffix))
    plot_feature_learning(checkpoint_diff, DIFFUSION_LABELS, 'Diffusion model',
                          figure_path('syn_diff_first_stage', suffix),
                          t_max=FIRST_STAGE_ITERS)
    plot_loss(checkpoint_diff, 'Diffusion model', figure_path('syn_diff_loss', suffix))

    # Classification: feature learning, loss, then train/test accuracy.
    plot_feature_learning(checkpoint_class, CLASSIFICATION_LABELS, 'Classification',
                          figure_path('syn_class_feat_learn', suffix))
    plot_loss(checkpoint_class, 'Classification', figure_path('syn_class_loss', suffix))
    plot_accuracy(checkpoint_class, 'Classification', figure_path('syn_class_acc', suffix))


def main():
    """Generate every figure for every configured run."""
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(FIGURE_DIR, exist_ok=True)
    for mu, suffix in RUNS:
        plot_run(mu, suffix)


if __name__ == "__main__":
    main()